/*
 * SPDX-FileCopyrightText: 2020 Stalwart Labs LLC <hello@stalw.art>
 *
 * SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-SEL
 */

use crate::{
    core::Session,
    inbound::DkimSign,
    queue::{MessageSource, MessageWrapper, spool::SmtpSpool},
};
use common::{
    Server, USER_AGENT,
    config::smtp::report::{AddressMatch, AggregateFrequency},
    expr::if_block::IfBlock,
    ipc::ReportingEvent,
};
use mail_auth::{
    common::headers::HeaderWriter,
    report::{AuthFailureType, DeliveryResult, Feedback, FeedbackType},
};
use mail_parser::DateTime;
use std::{future::Future, io, time::SystemTime};
use store::write::{ReportEvent, key::KeySerializer};
use tokio::io::{AsyncRead, AsyncWrite};

pub mod analysis;
pub mod dkim;
pub mod dmarc;
pub mod scheduler;
pub mod spf;
pub mod tls;

impl<T: AsyncWrite + AsyncRead + Unpin> Session<T> {
    /// Builds an authentication-failure feedback report for this session.
    ///
    /// The report carries:
    /// - the failure type `ft`;
    /// - the current UNIX time as the arrival date (0 if the clock reads before the epoch);
    /// - the session's remote IP as the source;
    /// - this server's hostname as the reporting MTA;
    /// - `USER_AGENT`;
    /// - a delivery result of `Reject` when `rejected` is true, otherwise `Unspecified`.
    pub fn new_auth_failure(&self, ft: AuthFailureType, rejected: bool) -> Feedback<'_> {
        let arrival_secs = SystemTime::now()
            .duration_since(SystemTime::UNIX_EPOCH)
            .map_or(0, |elapsed| elapsed.as_secs()) as i64;
        let outcome = match rejected {
            true => DeliveryResult::Reject,
            false => DeliveryResult::Unspecified,
        };

        Feedback::new(FeedbackType::AuthFailure)
            .with_auth_failure(ft)
            .with_arrival_date(arrival_secs)
            .with_source_ip(self.data.remote_ip)
            .with_reporting_mta(&self.hostname)
            .with_user_agent(USER_AGENT)
            .with_delivery_result(outcome)
    }

    /// Returns `true` when any envelope recipient matches a configured
    /// report-analysis address rule.
    ///
    /// Rules are prefix, suffix or exact matches against each recipient's
    /// lowercased address.
    pub fn is_report(&self) -> bool {
        let rules = &self.server.core.smtp.report.analysis.addresses;
        rules.iter().any(|rule| {
            self.data.rcpt_to.iter().any(|rcpt| match rule {
                AddressMatch::StartsWith(prefix) if rcpt.address_lcase.starts_with(prefix) => {
                    true
                }
                AddressMatch::EndsWith(suffix) if rcpt.address_lcase.ends_with(suffix) => true,
                AddressMatch::Equals(value) if rcpt.address_lcase.eq(value) => true,
                _ => false,
            })
        })
    }
}

/// Helpers for building, signing and queueing machine-generated messages
/// (reports and other autogenerated mail), and for handing report events
/// to the report scheduler.
pub trait SmtpReporting: Sync + Send {
    /// Builds a message from `from_addr` to every address in `rcpts`, DKIM-signs
    /// `report` with the signers selected by `sign_config`, and queues it.
    ///
    /// `parent_session_id` links the new message to the originating session.
    /// When `deliver_now` is false, the `Server` implementation may postpone
    /// delivery (it adds a random delay in non-`test_mode` builds).
    fn send_report(
        &self,
        from_addr: &str,
        rcpts: impl Iterator<Item = impl AsRef<str> + Sync + Send> + Sync + Send,
        report: Vec<u8>,
        sign_config: &IfBlock,
        deliver_now: bool,
        parent_session_id: u64,
    ) -> impl Future<Output = ()> + Send;

    /// Builds and queues `raw_message` from `from_addr` to every address in `rcpts`.
    ///
    /// The message is DKIM-signed only when `sign_config` is `Some`.
    fn send_autogenerated(
        &self,
        from_addr: impl AsRef<str> + Sync + Send,
        rcpts: impl Iterator<Item = impl AsRef<str> + Sync + Send> + Sync + Send,
        raw_message: Vec<u8>,
        sign_config: Option<&IfBlock>,
        parent_session_id: u64,
    ) -> impl Future<Output = ()> + Send;

    /// Converts `report` into a `ReportingEvent` and sends it to the report scheduler.
    ///
    /// In the `Server` implementation, a failed send is logged, not returned.
    fn schedule_report(
        &self,
        report: impl Into<ReportingEvent> + Sync + Send,
    ) -> impl Future<Output = ()> + Send;

    /// Evaluates `config` against `message` to choose DKIM signer names, signs
    /// `bytes` with each signer that can be resolved, and returns the
    /// concatenated signature header bytes.
    ///
    /// Returns `None` when no signature was produced.
    fn sign_message(
        &self,
        message: &mut MessageWrapper,
        config: &IfBlock,
        bytes: &[u8],
    ) -> impl Future<Output = Option<Vec<u8>>> + Send;
}

impl SmtpReporting for Server {
    async fn send_report(
        &self,
        from_addr: &str,
        rcpts: impl Iterator<Item = impl AsRef<str> + Sync + Send> + Sync + Send,
        report: Vec<u8>,
        sign_config: &IfBlock,
        deliver_now: bool,
        parent_session_id: u64,
    ) {
        // Build message
        let mut message = self.new_message(from_addr, parent_session_id);
        for rcpt_ in rcpts {
            message.add_recipient(rcpt_.as_ref(), self).await;
        }

        // Sign message
        let signature = self.sign_message(&mut message, sign_config, &report).await;

        // Schedule delivery at a random time between now and the next 3 hours
        if !deliver_now {
            #[cfg(not(feature = "test_mode"))]
            {
                use common::config::smtp::queue::QueueExpiry;
                use rand::Rng;

                let delivery_time = rand::rng().random_range(0u64..10800u64);
                for rcpt in &mut message.message.recipients {
                    rcpt.retry.due += delivery_time;
                    rcpt.notify.due += delivery_time;
                    if let QueueExpiry::Ttl(expires) = &mut rcpt.expires {
                        *expires += delivery_time;
                    }
                }
            }
        }

        // Queue message
        message
            .queue(
                signature.as_deref(),
                &report,
                parent_session_id,
                self,
                MessageSource::Report,
            )
            .await;
    }

    async fn send_autogenerated(
        &self,
        from_addr: impl AsRef<str> + Sync + Send,
        rcpts: impl Iterator<Item = impl AsRef<str> + Sync + Send> + Sync + Send,
        raw_message: Vec<u8>,
        sign_config: Option<&IfBlock>,
        parent_session_id: u64,
    ) {
        // Build message
        let mut message = self.new_message(from_addr.as_ref(), parent_session_id);
        for rcpt in rcpts {
            message.add_recipient(rcpt, self).await;
        }

        // Sign message
        let signature = if let Some(sign_config) = sign_config {
            self.sign_message(&mut message, sign_config, &raw_message)
                .await
        } else {
            None
        };

        // Queue message
        message
            .queue(
                signature.as_deref(),
                &raw_message,
                parent_session_id,
                self,
                MessageSource::Autogenerated,
            )
            .await;
    }

    async fn schedule_report(&self, report: impl Into<ReportingEvent> + Sync + Send) {
        if self.inner.ipc.report_tx.send(report.into()).await.is_err() {
            trc::event!(
                Server(trc::ServerEvent::ThreadError),
                CausedBy = trc::location!(),
                Details = "Failed to send event to ReportScheduler"
            );
        }
    }

    async fn sign_message(
        &self,
        message: &mut MessageWrapper,
        config: &IfBlock,
        bytes: &[u8],
    ) -> Option<Vec<u8>> {
        let signers = self
            .eval_if::<Vec<String>, _>(config, &message.message, message.span_id)
            .await
            .unwrap_or_default();
        if !signers.is_empty() {
            let mut headers = Vec::with_capacity(64);
            for signer in signers.iter() {
                if let Some(signer) = self.get_dkim_signer(signer, message.span_id) {
                    match signer.sign(bytes) {
                        Ok(signature) => {
                            signature.write_header(&mut headers);
                        }
                        Err(err) => {
                            trc::error!(
                                trc::Error::from(err)
                                    .span_id(message.span_id)
                                    .details("Failed to sign message")
                                    .caused_by(trc::location!())
                            );
                        }
                    }
                }
            }
            if !headers.is_empty() {
                return Some(headers);
            }
        }
        None
    }
}

/// Period arithmetic for report aggregation frequencies, in UNIX seconds.
pub trait AggregateTimestamp {
    /// Start of the current period, based on the system clock.
    fn to_timestamp(&self) -> u64;
    /// Start of the period that contains `dt`.
    fn to_timestamp_(&self, dt: DateTime) -> u64;
    /// Length of one period in seconds.
    fn as_secs(&self) -> u64;
    /// End of the current period: its start plus its length.
    fn due(&self) -> u64;
}

impl AggregateTimestamp for AggregateFrequency {
    /// Start of the current aggregation period, as a UNIX timestamp.
    ///
    /// Reads the system clock; a clock before the UNIX epoch is treated as 0.
    fn to_timestamp(&self) -> u64 {
        let now_secs = SystemTime::now()
            .duration_since(SystemTime::UNIX_EPOCH)
            .map_or(0, |elapsed| elapsed.as_secs());
        self.to_timestamp_(DateTime::from_timestamp(now_secs as i64))
    }

    /// Truncates `dt` to the start of its period and returns it as a UNIX timestamp.
    ///
    /// * `Hourly`: minutes and seconds cleared.
    /// * `Daily`: hours, minutes and seconds cleared.
    /// * `Weekly`: same as `Daily`, then moved back `dt.day_of_week()` days.
    ///   NOTE(review): the week starts on whatever day `mail_parser`'s
    ///   `day_of_week()` numbers as 0; confirm.
    /// * `Never`: `dt` unchanged.
    fn to_timestamp_(&self, mut dt: DateTime) -> u64 {
        // The weekday is read before any field of `dt` is cleared.
        let days_back: i64 = match self {
            AggregateFrequency::Weekly => dt.day_of_week() as i64,
            AggregateFrequency::Hourly | AggregateFrequency::Daily | AggregateFrequency::Never => 0,
        };

        match self {
            AggregateFrequency::Never => {}
            AggregateFrequency::Hourly => {
                dt.minute = 0;
                dt.second = 0;
            }
            AggregateFrequency::Daily | AggregateFrequency::Weekly => {
                dt.hour = 0;
                dt.minute = 0;
                dt.second = 0;
            }
        }

        (dt.to_timestamp() - 86400 * days_back) as u64
    }

    /// Length of one aggregation period in seconds (`Never` is 0).
    fn as_secs(&self) -> u64 {
        const HOUR: u64 = 3600;
        const DAY: u64 = 24 * HOUR;

        match self {
            AggregateFrequency::Never => 0,
            AggregateFrequency::Hourly => HOUR,
            AggregateFrequency::Daily => DAY,
            AggregateFrequency::Weekly => 7 * DAY,
        }
    }

    /// End of the current period: its start plus one period length.
    fn due(&self) -> u64 {
        let period_start = self.to_timestamp();
        period_start + self.as_secs()
    }
}

/// An [`io::Write`] sink that discards its input and only tracks a byte budget.
///
/// Each `write` either consumes the whole buffer from the budget or, if the
/// buffer is larger than what remains, fails with an `ErrorKind::Other`
/// error ("Size exceeded"). A failed write leaves the budget unchanged. This
/// makes it possible to test whether serializing something into it stays
/// within `bytes_left` bytes.
#[derive(Debug)]
pub struct SerializedSize {
    /// Bytes that may still be written before writes start failing.
    bytes_left: usize,
}

impl SerializedSize {
    /// Creates a sink that accepts at most `bytes_left` bytes in total.
    pub fn new(bytes_left: usize) -> Self {
        Self { bytes_left }
    }
}

impl io::Write for SerializedSize {
    /// Deducts `buf.len()` from the budget and reports the full buffer as written.
    ///
    /// # Errors
    /// Returns an `ErrorKind::Other` error when `buf` exceeds the remaining
    /// budget. Nothing is consumed in that case.
    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
        match self.bytes_left.checked_sub(buf.len()) {
            Some(remaining) => {
                self.bytes_left = remaining;
                Ok(buf.len())
            }
            None => Err(io::Error::other("Size exceeded")),
        }
    }

    /// No-op: nothing is buffered.
    fn flush(&mut self) -> io::Result<()> {
        Ok(())
    }
}

/// Serialized keys that identify a pending report batch.
///
/// NOTE(review): judging by the name, these are used as lock identifiers by
/// the report scheduler; confirm against callers.
pub trait ReportLock {
    /// Key for a TLS report batch.
    fn tls_lock(&self) -> Vec<u8>;
    /// Key for a DMARC report batch.
    fn dmarc_lock(&self) -> Vec<u8>;
}

impl ReportLock for ReportEvent {
    /// TLS key layout: tag byte `0`, the `due` value, then the raw domain bytes.
    fn tls_lock(&self) -> Vec<u8> {
        let capacity = 1 + std::mem::size_of::<u64>() + self.domain.len();
        let key = KeySerializer::new(capacity)
            .write(0u8)
            .write(self.due)
            .write(self.domain.as_bytes());
        key.finalize()
    }

    /// DMARC key layout: tag byte `1`, the `due` value, the policy hash, then
    /// the raw domain bytes.
    ///
    /// The tag byte keeps DMARC keys distinct from TLS keys.
    fn dmarc_lock(&self) -> Vec<u8> {
        let capacity = 1 + 2 * std::mem::size_of::<u64>() + self.domain.len();
        let key = KeySerializer::new(capacity)
            .write(1u8)
            .write(self.due)
            .write(self.policy_hash)
            .write(self.domain.as_bytes());
        key.finalize()
    }
}
